import importlib.util
from pathlib import Path

import numpy as np
import pandas as pd
from linearmodels.iv import IV2SLS


def load_shiftshare_module(root: Path):
    """Load the shift-share IV analysis script from its file path.

    The script's file name starts with a digit, so it cannot be pulled in with
    a plain ``import`` statement; it is executed from its path instead.

    Args:
        root: Directory containing ``analysis/19_iv_shiftshare_rollout_absorb.py``.

    Returns:
        The executed module object, named ``"shiftshare"``. This function does
        not register it in ``sys.modules``.

    Raises:
        ImportError: If no import spec (or no loader) can be built for the path.
    """
    path = root / "analysis" / "19_iv_shiftshare_rollout_absorb.py"
    spec = importlib.util.spec_from_file_location("shiftshare", path)
    # Checked explicitly (not with `assert`, which disappears under `python -O`)
    # and before the spec is used to build the module.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


def fit_iv_interaction_absorb_weighted(mod, df: pd.DataFrame, y: str, controls: list[str], weight_col: str) -> IV2SLS:
    """Fit the education-interacted 2SLS model on weighted observations.

    ``netusoft`` and ``netusoft * edu_high`` are the endogenous regressors; the
    three ``z_ss_*`` columns and their products with ``edu_high`` are the
    instruments; ``controls`` plus ``edu_high`` are exogenous. Every model
    column is first passed through ``mod.absorb_two_way`` with ``region`` and
    ``ct`` as the two absorbed factors, then scaled by the square root of the
    weight before being handed to ``IV2SLS``. Standard errors are clustered on
    ``region``.

    Args:
        mod: Object exposing ``absorb_two_way(frame, cols, fe_a=, fe_b=, weights=)``.
            Presumably it returns ``cols`` residualized on the two factors, one
            row per input row -- its implementation is not visible here.
        df: Data holding the outcome, ``netusoft``, ``edu_high``, the
            ``z_ss_*`` instruments, ``ct``, ``region``, the weight column and
            the controls.
        y: Name of the outcome column.
        controls: Names of additional exogenous columns.
        weight_col: Name of the observation-weight column. Rows whose weight is
            missing, non-numeric, non-finite or not strictly positive are dropped.

    Returns:
        Whatever ``IV2SLS(...).fit(...)`` returns.
        NOTE(review): the signature is annotated ``IV2SLS``, but the value
        returned is the fit result, not the model -- confirm and adjust.

    Raises:
        ValueError: If no rows survive ``dropna`` or the weight filter, or if
            the absorbed frame does not have one row per retained observation.
    """
    instruments = ["z_ss_cov30", "z_ss_cov100", "z_ss_fttp"]
    cols = [y, "netusoft", "edu_high", *instruments, "ct", "region", weight_col, *controls]
    d = df[cols].dropna().copy()
    if d.empty:
        raise ValueError(f"No rows after dropna for outcome={y}")

    w = pd.to_numeric(d[weight_col], errors="coerce")
    w = w.where(np.isfinite(w) & (w > 0))
    # One positional mask filters both the data and the weights. A label-based
    # lookup (`w.loc[d.index]`) would duplicate rows, and re-admit rejected
    # weights, whenever the index contains repeated labels.
    keep = w.notna().to_numpy()
    d = d.loc[keep].copy()
    w = w.loc[keep]
    if d.empty:
        raise ValueError(f"No rows after weight filtering for outcome={y}")

    d["netusoft_x_edu"] = d["netusoft"] * d["edu_high"]
    instruments_x_edu = [f"{z}_x_edu" for z in instruments]
    for z, z_x_edu in zip(instruments, instruments_x_edu):
        d[z_x_edu] = d[z] * d["edu_high"]

    endog_cols = ["netusoft", "netusoft_x_edu"]
    instr_cols = [*instruments, *instruments_x_edu]
    exog_cols = [*controls, "edu_high"]
    model_cols = [y, *endog_cols, *instr_cols, *exog_cols]
    r = mod.absorb_two_way(d, model_cols, fe_a="region", fe_b="ct", weights=w)
    if len(r) != len(d):
        raise ValueError(
            f"absorb_two_way returned {len(r)} rows for {len(d)} observations (outcome={y})"
        )

    # Scaling every variable by sqrt(weight) turns the unweighted estimator
    # into its weighted counterpart. The scaling is positional, so it does not
    # depend on what index the absorbed frame carries.
    sw = np.sqrt(w.to_numpy(dtype=float))
    y_r = r[y].multiply(sw).rename(y)
    exog = r[exog_cols].multiply(sw, axis=0)
    endog = r[endog_cols].multiply(sw, axis=0)
    instr = r[instr_cols].multiply(sw, axis=0)

    res = IV2SLS(y_r, exog=exog, endog=endog, instruments=instr).fit(cov_type="clustered", clusters=d["region"])
    return res


def to_latex_tabular(df: pd.DataFrame) -> str:
    """Render the summary table as a LaTeX ``tabular`` with booktabs rules.

    Args:
        df: One row per model. Must contain ``label``, ``nobs``, the low/high
            education coefficient columns, the delta columns, the two
            first-stage F columns and ``AR_pvalue``.

    Returns:
        The ``tabular`` environment as a newline-joined string with no trailing
        newline. Missing numeric values become empty cells. In labels, the
        characters ``&``, ``%``, ``#`` and ``_`` are backslash-escaped unless
        they already are, so a label cannot break the row or comment out its
        tail.
    """
    import re  # function-scope import: only this block needs it

    label_col = "label"
    # Format spec for each numeric column, in output order.
    number_formats = {
        "nobs": ".0f",
        "coef_lowEdu": ".4f",
        "se_lowEdu": ".4f",
        "coef_highEdu": ".4f",
        "coef_deltaHighMinusLow": ".4f",
        "se_deltaHighMinusLow": ".4f",
        "fsF_netusoft": ".2f",
        "fsF_netusoft_x_edu": ".2f",
        "AR_pvalue": ".4g",
    }
    cols = [label_col, *number_formats]
    header = [
        "Model",
        "N",
        "b(LowEdu)",
        "se(LowEdu)",
        "b(HighEdu)",
        "Delta(High-Low)",
        "se(Delta)",
        "F(net)",
        "F(net$\\times$edu)",
        "AR p",
    ]
    # Matches a special character that is not already escaped. `$`, braces and
    # backslashes are left alone so deliberate markup in a label passes through.
    unescaped_special = re.compile(r"(?<!\\)([&%#_])")

    def _cell(col: str, value) -> str:
        if col == label_col:
            return unescaped_special.sub(r"\\\1", str(value))
        return format(value, number_formats[col]) if pd.notna(value) else ""

    out = [
        "\\begin{tabular}{lrrrrrrrrr}",
        "\\toprule",
        " & ".join(header) + " \\\\",
        "\\midrule",
    ]
    for record in df[cols].itertuples(index=False, name=None):
        cells = [_cell(col, value) for col, value in zip(cols, record)]
        out.append(" & ".join(cells) + " \\\\")
    out.append("\\bottomrule")
    out.append("\\end{tabular}")
    return "\n".join(out)


def main() -> None:
    """Fit unweighted and weighted 2SLS models and write the comparison table.

    The project root is taken to be the parent of the directory holding this
    script. For each outcome, the unweighted fit comes from the loaded
    shift-share module and the weighted fit from
    ``fit_iv_interaction_absorb_weighted``; both are summarized with the
    module's ``summarize_result``. The table is written as CSV and Markdown
    under ``outputs/`` and as LaTeX under ``paper_joc/tables/``.
    """
    root = Path(__file__).resolve().parents[1]
    out_dir = root / "outputs"
    paper_tables = root / "paper_joc" / "tables"
    for directory in (out_dir, paper_tables):
        directory.mkdir(parents=True, exist_ok=True)

    mod = load_shiftshare_module(root)
    df = mod.build_dataset(root)

    controls = ["agea", "gndr", "hinctnta"]
    weight_col = "anweight"
    outcomes = [("y_part_index", "Participation index"), ("y_vote", "Vote (yes/no)")]

    rows = []
    for y, ylab in outcomes:
        # Unweighted baseline from the shared analysis module.
        res_u = mod.fit_iv_interaction_absorb(df, y, controls=controls)
        rows.append(mod.summarize_result(res_u, y=y, label=f"Unweighted 2SLS ({ylab})"))

        # Same specification with observation weights.
        res_w = fit_iv_interaction_absorb_weighted(mod, df, y=y, controls=controls, weight_col=weight_col)
        rows.append(mod.summarize_result(res_w, y=y, label=f"Weighted 2SLS ({ylab}); w={weight_col}"))

    tab = pd.DataFrame(rows)
    # Let pandas write the file itself. Rendering to a string first uses
    # os.linesep as the line terminator, and writing that string in text mode
    # translates "\n" a second time, which doubles carriage returns on Windows.
    tab.to_csv(out_dir / "iv_shiftshare_weighted_robustness.csv", index=False, encoding="utf-8")
    (out_dir / "iv_shiftshare_weighted_robustness.md").write_text(mod.to_markdown_table(tab) + "\n", encoding="utf-8")
    (paper_tables / "iv_shiftshare_weighted_robustness.tex").write_text(to_latex_tabular(tab) + "\n", encoding="utf-8")
    print("Wrote weighted robustness table to paper_joc/tables/iv_shiftshare_weighted_robustness.tex")


# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
